from transforms import get_transforms
from torch.utils.data import DataLoader
import torchvision
import torch

import pytorch_lightning as pl
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split



class CelebADataModule(pl.LightningDataModule):
    """LightningDataModule wrapping torchvision's ``CelebA`` dataset.

    The official 'train' split is divided 90/10 into training and validation
    subsets; the official 'test' split is used for testing.

    Args:
        data_dir: Root directory for ``datasets.CelebA`` (downloaded if absent).
        batch_size: Batch size used by every dataloader.
        train_transform: Transform for the 'train' split, and therefore also
            for the validation subset carved out of it.
        test_transform: Transform for the 'test' split.
        target_type: Forwarded unchanged to ``datasets.CelebA``.
        split_seed: Seed for the train/validation ``random_split`` so the
            partition is identical across runs and across processes.
        num_workers: Worker processes for each ``DataLoader``.
    """

    def __init__(self, data_dir, batch_size, train_transform, test_transform,
                 target_type='attr', split_seed=42, num_workers=4):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.train_transform = train_transform
        self.test_transform = test_transform
        self.target_type = target_type
        self.split_seed = split_seed
        self.num_workers = num_workers

    def setup(self, stage: str = None):
        """Instantiate the datasets needed for ``stage``.

        'fit' and 'validate' build ``train_set``/``val_set``, 'test' builds
        ``test_set`` and ``None`` builds all of them.
        """
        # 'validate' is included because trainer.validate() calls
        # setup('validate') and then val_dataloader(), which needs val_set.
        if stage in ('fit', 'validate', None):
            full_train_set = datasets.CelebA(
                self.data_dir, split='train', target_type=self.target_type,
                transform=self.train_transform, download=True)

            train_size = int(0.9 * len(full_train_set))
            val_size = len(full_train_set) - train_size
            # Seeded generator: setup() runs in every process, and an
            # unseeded split could hand each rank a different partition.
            # NOTE(review): val_set shares full_train_set and thus
            # train_transform; if it augments, validation images are augmented.
            self.train_set, self.val_set = random_split(
                full_train_set, [train_size, val_size],
                generator=torch.Generator().manual_seed(self.split_seed))

        if stage in ('test', None):
            self.test_set = datasets.CelebA(
                self.data_dir, split='test', target_type=self.target_type,
                transform=self.test_transform, download=True)

    def train_dataloader(self):
        """Return a shuffled loader over the training subset."""
        return DataLoader(self.train_set, batch_size=self.batch_size,
                          shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        """Return an unshuffled loader over the validation subset."""
        return DataLoader(self.val_set, batch_size=self.batch_size,
                          shuffle=False, num_workers=self.num_workers)

    def test_dataloader(self):
        """Return an unshuffled loader over the test split."""
        return DataLoader(self.test_set, batch_size=self.batch_size,
                          shuffle=False, num_workers=self.num_workers)

class MNISTDataModule(pl.LightningDataModule):
    """LightningDataModule wrapping torchvision's ``MNIST`` dataset.

    The training split is divided into 55000 training and 5000 validation
    samples; the held-out split (``train=False``) is used for testing.

    Args:
        data_dir: Root directory for ``datasets.MNIST`` (downloaded if absent).
        batch_size: Batch size used by every dataloader.
        train_transform: Transform for the training split, and therefore also
            for the validation subset carved out of it.
        test_transform: Transform for the test split.
        split_seed: Seed for the train/validation ``random_split`` so the
            partition is identical across runs and across processes.
        num_workers: Worker processes for each ``DataLoader``.
    """

    def __init__(self, data_dir, batch_size, train_transform, test_transform,
                 split_seed=42, num_workers=4):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

        self.train_transform = train_transform
        self.test_transform = test_transform
        self.split_seed = split_seed
        self.num_workers = num_workers

    def setup(self, stage: str = None):
        """Instantiate the datasets needed for ``stage``.

        'fit' and 'validate' build ``train_set``/``val_set``, 'test' builds
        ``test_set`` and ``None`` builds all of them.
        """
        # 'validate' is included because trainer.validate() calls
        # setup('validate') and then val_dataloader(), which needs val_set.
        if stage in ('fit', 'validate', None):
            full_train_set = datasets.MNIST(
                self.data_dir, train=True, download=True, transform=self.train_transform)

            # Seeded generator: setup() runs in every process, and an
            # unseeded split could hand each rank a different partition.
            # NOTE(review): val_set shares full_train_set and thus
            # train_transform; if it augments, validation images are augmented.
            self.train_set, self.val_set = random_split(
                full_train_set, [55000, 5000],
                generator=torch.Generator().manual_seed(self.split_seed))

        if stage in ('test', None):
            self.test_set = datasets.MNIST(
                self.data_dir, train=False, download=True, transform=self.test_transform)

    def train_dataloader(self):
        """Return a shuffled loader over the training subset."""
        return DataLoader(self.train_set, batch_size=self.batch_size,
                          shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        """Return an unshuffled loader over the validation subset."""
        return DataLoader(self.val_set, batch_size=self.batch_size,
                          num_workers=self.num_workers)

    def test_dataloader(self):
        """Return an unshuffled loader over the test split."""
        return DataLoader(self.test_set, batch_size=self.batch_size,
                          num_workers=self.num_workers)

class SVHNDataModule(pl.LightningDataModule):
    """LightningDataModule wrapping torchvision's ``SVHN`` dataset.

    The 'train' split is divided 90/10 into training and validation subsets;
    the 'test' split is used for testing.

    Args:
        data_dir: Root directory for ``datasets.SVHN`` (downloaded if absent).
        batch_size: Batch size used by every dataloader.
        train_transform: Transform for the 'train' split, and therefore also
            for the validation subset carved out of it.
        test_transform: Transform for the 'test' split.
        split_seed: Seed for the train/validation ``random_split`` so the
            partition is identical across runs and across processes.
        num_workers: Worker processes for each ``DataLoader``.
    """

    def __init__(self, data_dir, batch_size, train_transform, test_transform,
                 split_seed=42, num_workers=4):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

        self.train_transform = train_transform
        self.test_transform = test_transform
        self.split_seed = split_seed
        self.num_workers = num_workers

    def setup(self, stage: str = None):
        """Instantiate the datasets needed for ``stage``.

        'fit' and 'validate' build ``train_set``/``val_set``, 'test' builds
        ``test_set`` and ``None`` builds all of them.
        """
        # 'validate' is included because trainer.validate() calls
        # setup('validate') and then val_dataloader(), which needs val_set.
        if stage in ('fit', 'validate', None):
            full_train_set = datasets.SVHN(
                self.data_dir, split='train', download=True, transform=self.train_transform)

            train_size = int(0.9 * len(full_train_set))
            val_size = len(full_train_set) - train_size
            # Seeded generator: setup() runs in every process, and an
            # unseeded split could hand each rank a different partition.
            # NOTE(review): val_set shares full_train_set and thus
            # train_transform; if it augments, validation images are augmented.
            self.train_set, self.val_set = random_split(
                full_train_set, [train_size, val_size],
                generator=torch.Generator().manual_seed(self.split_seed))

        if stage in ('test', None):
            self.test_set = datasets.SVHN(
                self.data_dir, split='test', download=True, transform=self.test_transform)

    def train_dataloader(self):
        """Return a shuffled loader over the training subset."""
        return DataLoader(self.train_set, batch_size=self.batch_size,
                          shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        """Return an unshuffled loader over the validation subset."""
        return DataLoader(self.val_set, batch_size=self.batch_size,
                          num_workers=self.num_workers)

    def test_dataloader(self):
        """Return an unshuffled loader over the test split."""
        return DataLoader(self.test_set, batch_size=self.batch_size,
                          num_workers=self.num_workers)


class CIFAR100DataModule(pl.LightningDataModule):
    """LightningDataModule wrapping torchvision's ``CIFAR100`` dataset.

    The training split is divided 90/10 into training and validation subsets;
    the held-out split (``train=False``) is used for testing.

    Args:
        data_dir: Root directory for ``datasets.CIFAR100`` (downloaded if absent).
        batch_size: Batch size used by every dataloader.
        train_transform: Transform for the training split, and therefore also
            for the validation subset and ``extract_class_samples``.
        test_transform: Transform for the test split.
        split_seed: Seed for the train/validation ``random_split`` so the
            partition is identical across runs and across processes.
        num_workers: Worker processes for each ``DataLoader``.
    """

    def __init__(self, data_dir, batch_size, train_transform, test_transform,
                 split_seed=42, num_workers=4):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

        self.train_transform = train_transform
        self.test_transform = test_transform
        self.split_seed = split_seed
        self.num_workers = num_workers

    def setup(self, stage: str = None):
        """Instantiate the datasets needed for ``stage``.

        'fit' and 'validate' build ``train_set``/``val_set``, 'test' builds
        ``test_set`` and ``None`` builds all of them.
        """
        # 'validate' is included because trainer.validate() calls
        # setup('validate') and then val_dataloader(), which needs val_set.
        if stage in ('fit', 'validate', None):
            full_train_set = datasets.CIFAR100(
                self.data_dir, train=True, download=True, transform=self.train_transform)

            train_size = int(0.9 * len(full_train_set))
            val_size = len(full_train_set) - train_size
            # Seeded generator: setup() runs in every process, and an
            # unseeded split could hand each rank a different partition.
            # NOTE(review): val_set shares full_train_set and thus
            # train_transform; if it augments, validation images are augmented.
            self.train_set, self.val_set = random_split(
                full_train_set, [train_size, val_size],
                generator=torch.Generator().manual_seed(self.split_seed))

        if stage in ('test', None):
            self.test_set = datasets.CIFAR100(
                self.data_dir, train=False, download=True, transform=self.test_transform)

    def train_dataloader(self):
        """Return a shuffled loader over the training subset."""
        return DataLoader(self.train_set, batch_size=self.batch_size,
                          shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        """Return an unshuffled loader over the validation subset."""
        return DataLoader(self.val_set, batch_size=self.batch_size,
                          num_workers=self.num_workers)

    def test_dataloader(self):
        """Return an unshuffled loader over the test split."""
        return DataLoader(self.test_set, batch_size=self.batch_size,
                          num_workers=self.num_workers)

    def extract_class_samples(self, num_samples_per_class=1000):
        """Return up to ``num_samples_per_class`` random training images per class.

        Images come from the full training split with ``train_transform``
        applied. They are stacked along a new leading dimension, grouped by
        label 0..99 in ascending order. Within a class the order is random
        (drawn from torch's global RNG). A class with fewer images than
        requested contributes all of them; CIFAR-100's training split holds
        500 per class, so the default of 1000 returns every image.

        Raises:
            ValueError: if ``num_samples_per_class`` is not positive.
        """
        if num_samples_per_class <= 0:
            raise ValueError(
                f"num_samples_per_class must be positive, got {num_samples_per_class}")

        num_classes = 100
        full_train_set = datasets.CIFAR100(
            self.data_dir, train=True, transform=self.train_transform, download=True)
        targets = full_train_set.targets

        # Choose indices from the label list alone so only the selected
        # images are decoded and transformed.
        picked = {label: [] for label in range(num_classes)}
        classes_left = num_classes
        for idx in torch.randperm(len(full_train_set)).tolist():
            bucket = picked[targets[idx]]
            if len(bucket) < num_samples_per_class:
                bucket.append(idx)
                if len(bucket) == num_samples_per_class:
                    classes_left -= 1
                    if classes_left == 0:
                        break

        images = [full_train_set[idx][0]
                  for label in range(num_classes) for idx in picked[label]]
        return torch.stack(images, 0)



class CIFAR10DataModule(pl.LightningDataModule):
    """LightningDataModule wrapping torchvision's ``CIFAR10`` dataset.

    The training split is divided into 45000 training and 5000 validation
    samples; the held-out split (``train=False``) is used for testing.

    Args:
        data_dir: Root directory for ``datasets.CIFAR10`` (downloaded if absent).
        batch_size: Batch size used by every dataloader.
        train_transform: Transform for the training split, and therefore also
            for the validation subset and ``extract_class_samples``.
        test_transform: Transform for the test split.
        split_seed: Seed for the train/validation ``random_split`` so the
            partition is identical across runs and across processes.
        num_workers: Worker processes for each ``DataLoader``.
    """

    def __init__(self, data_dir, batch_size, train_transform, test_transform,
                 split_seed=42, num_workers=4):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

        self.train_transform = train_transform
        self.test_transform = test_transform
        self.split_seed = split_seed
        self.num_workers = num_workers

    def setup(self, stage: str = None):
        """Instantiate the datasets needed for ``stage``.

        'fit' and 'validate' build ``train_set``/``val_set``, 'test' builds
        ``test_set`` and ``None`` builds all of them.
        """
        # 'validate' is included because trainer.validate() calls
        # setup('validate') and then val_dataloader(), which needs val_set.
        if stage in ('fit', 'validate', None):
            full_train_set = datasets.CIFAR10(
                self.data_dir, train=True, download=True, transform=self.train_transform)

            # Seeded generator: setup() runs in every process, and an
            # unseeded split could hand each rank a different partition.
            # NOTE(review): val_set shares full_train_set and thus
            # train_transform; if it augments, validation images are augmented.
            self.train_set, self.val_set = random_split(
                full_train_set, [45000, 5000],
                generator=torch.Generator().manual_seed(self.split_seed))

        if stage in ('test', None):
            self.test_set = datasets.CIFAR10(
                self.data_dir, train=False, download=True, transform=self.test_transform)

    def train_dataloader(self):
        """Return a shuffled loader over the training subset."""
        return DataLoader(self.train_set, batch_size=self.batch_size,
                          shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        """Return an unshuffled loader over the validation subset."""
        return DataLoader(self.val_set, batch_size=self.batch_size,
                          num_workers=self.num_workers)

    def test_dataloader(self):
        """Return an unshuffled loader over the test split."""
        return DataLoader(self.test_set, batch_size=self.batch_size,
                          num_workers=self.num_workers)

    def extract_class_samples(self, num_samples_per_class=1000):
        """Return up to ``num_samples_per_class`` random training images per class.

        Images come from the full training split with ``train_transform``
        applied. They are stacked along a new leading dimension, grouped by
        label 0..9 in ascending order. Within a class the order is random
        (drawn from torch's global RNG). A class with fewer images than
        requested contributes all of them.

        Raises:
            ValueError: if ``num_samples_per_class`` is not positive.
        """
        if num_samples_per_class <= 0:
            raise ValueError(
                f"num_samples_per_class must be positive, got {num_samples_per_class}")

        num_classes = 10
        full_train_set = datasets.CIFAR10(
            self.data_dir, train=True, transform=self.train_transform, download=True)
        targets = full_train_set.targets

        # Choose indices from the label list alone so only the selected
        # images are decoded and transformed.
        picked = {label: [] for label in range(num_classes)}
        classes_left = num_classes
        for idx in torch.randperm(len(full_train_set)).tolist():
            bucket = picked[targets[idx]]
            if len(bucket) < num_samples_per_class:
                bucket.append(idx)
                if len(bucket) == num_samples_per_class:
                    classes_left -= 1
                    if classes_left == 0:
                        break

        images = [full_train_set[idx][0]
                  for label in range(num_classes) for idx in picked[label]]
        return torch.stack(images, 0)


class carsDataModule(pl.LightningDataModule):
    """LightningDataModule wrapping torchvision's ``StanfordCars`` dataset.

    The 'train' split is divided into 6144 training and 2000 validation
    samples; the 'test' split is used for testing. Train and test share the
    same deterministic preprocessing: resize to 256 on the short side, then
    resize to exactly 224x224, convert to a tensor, and normalize each channel
    with mean 0.5 and std 0.5.

    Args:
        data_dir: Root directory for ``datasets.StanfordCars``.
        batch_size: Batch size used by every dataloader.
        split_seed: Seed for the train/validation ``random_split`` so the
            partition is identical across runs and across processes.
        num_workers: Worker processes for each ``DataLoader``.
    """

    def __init__(self, data_dir: str = "path/to/dir", batch_size: int = 32,
                 split_seed: int = 42, num_workers: int = 4):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.split_seed = split_seed
        self.num_workers = num_workers

        def _pipeline():
            # Augmentation (RandomHorizontalFlip / RandomCrop) is currently
            # disabled, so train and test use the same steps.
            return transforms.Compose([
                transforms.Resize(256),  # short side -> 256, aspect ratio kept
                # Resizes (does not crop) to exactly 224x224, so the aspect
                # ratio is not preserved.
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])

        # Separate Compose objects so neither pipeline aliases the other.
        self.train_transform = _pipeline()
        self.test_transform = _pipeline()

    def setup(self, stage: str = None):
        """Instantiate the datasets needed for ``stage``.

        'fit' and 'validate' build ``cars_train``/``cars_val``, 'test' builds
        ``cars_test`` and ``None`` builds all of them.
        """
        # 'validate' is included because trainer.validate() calls
        # setup('validate') and then val_dataloader(), which needs cars_val.
        if stage in ('fit', 'validate', None):
            full_train = datasets.StanfordCars(
                self.data_dir, split='train', download=True, transform=self.train_transform)

            # Seeded generator: setup() runs in every process, and an
            # unseeded split could hand each rank a different partition.
            self.cars_train, self.cars_val = random_split(
                full_train, [6144, 2000],
                generator=torch.Generator().manual_seed(self.split_seed))

        if stage in ('test', None):
            self.cars_test = datasets.StanfordCars(
                self.data_dir, split='test', download=True, transform=self.test_transform)

    def train_dataloader(self):
        """Return a shuffled loader over the training subset."""
        return DataLoader(self.cars_train, batch_size=self.batch_size,
                          shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        """Return an unshuffled loader over the validation subset."""
        return DataLoader(self.cars_val, batch_size=self.batch_size,
                          num_workers=self.num_workers)

    def test_dataloader(self):
        """Return an unshuffled loader over the test split."""
        return DataLoader(self.cars_test, batch_size=self.batch_size,
                          num_workers=self.num_workers)


class STL10DataModule(pl.LightningDataModule):
    """LightningDataModule wrapping torchvision's ``STL10`` dataset.

    The 'train' split is divided into 4500 training and 500 validation
    samples; the 'test' split is used for testing. Train and test share the
    same preprocessing: convert to a tensor and normalize each channel with
    mean 0.5 and std 0.5.

    Args:
        data_dir: Root directory for ``datasets.STL10`` (downloaded if absent).
        batch_size: Batch size used by every dataloader.
        split_seed: Seed for the train/validation ``random_split`` so the
            partition is identical across runs and across processes.
        num_workers: Worker processes for each ``DataLoader``.
    """

    def __init__(self, data_dir: str = "path/to/dir", batch_size: int = 32,
                 split_seed: int = 42, num_workers: int = 4):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.split_seed = split_seed
        self.num_workers = num_workers

        def _pipeline():
            # Augmentation (RandomHorizontalFlip / RandomCrop) is currently
            # disabled, so train and test use the same steps.
            return transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])

        # Separate Compose objects so neither pipeline aliases the other.
        self.train_transform = _pipeline()
        self.test_transform = _pipeline()

    def setup(self, stage: str = None):
        """Instantiate the datasets needed for ``stage``.

        'fit' and 'validate' build ``stl10_train``/``stl10_val``, 'test'
        builds ``stl10_test`` and ``None`` builds all of them.
        """
        # 'validate' is included because trainer.validate() calls
        # setup('validate') and then val_dataloader(), which needs stl10_val.
        if stage in ('fit', 'validate', None):
            full_train = datasets.STL10(
                self.data_dir, split='train', download=True, transform=self.train_transform)

            # Seeded generator: setup() runs in every process, and an
            # unseeded split could hand each rank a different partition.
            self.stl10_train, self.stl10_val = random_split(
                full_train, [4500, 500],
                generator=torch.Generator().manual_seed(self.split_seed))

        if stage in ('test', None):
            self.stl10_test = datasets.STL10(
                self.data_dir, split='test', download=True, transform=self.test_transform)

    def train_dataloader(self):
        """Return a shuffled loader over the training subset."""
        return DataLoader(self.stl10_train, batch_size=self.batch_size,
                          shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        """Return an unshuffled loader over the validation subset."""
        return DataLoader(self.stl10_val, batch_size=self.batch_size,
                          num_workers=self.num_workers)

    def test_dataloader(self):
        """Return an unshuffled loader over the test split."""
        return DataLoader(self.stl10_test, batch_size=self.batch_size,
                          num_workers=self.num_workers)

# Usage:
# data_module = STL10DataModule(data_dir="path/to/dir", batch_size=32)
# data_module.setup('fit')


def get_datamodule(config):
    """Build the LightningDataModule selected by ``config['DATASET']``.

    ``get_transforms(config['DATASET'])`` is always called. Its "train" and
    "test" entries are passed to the CIFAR100, SVHN, CIFAR10, CelebA and
    MNIST modules. The STL10 and cars modules define their own transforms,
    print the batch size, and are returned already set up for 'fit'.

    Args:
        config: Mapping with at least 'DATASET', 'DATAPATH' and 'BATCH_SIZE'.

    Returns:
        The constructed data module.

    Raises:
        ValueError: if ``config['DATASET']`` names no known data module.
    """
    dataset = config['DATASET']
    # Named ``split_transforms`` so it does not shadow the torchvision
    # ``transforms`` module imported at file level.
    split_transforms = get_transforms(dataset)

    # Data modules that take externally supplied train/test transforms.
    with_external_transforms = {
        "CIFAR100": CIFAR100DataModule,
        "SVHN": SVHNDataModule,
        "CIFAR10": CIFAR10DataModule,
        "CelebA": CelebADataModule,
        "MNIST": MNISTDataModule,
    }
    if dataset in with_external_transforms:
        return with_external_transforms[dataset](
            data_dir=config['DATAPATH'],
            batch_size=config['BATCH_SIZE'],
            train_transform=split_transforms["train"],
            test_transform=split_transforms["test"],
        )

    # Data modules with built-in transforms, set up eagerly for 'fit'.
    with_builtin_transforms = {
        "STL10": STL10DataModule,
        "cars": carsDataModule,
    }
    if dataset in with_builtin_transforms:
        print(f"BATCH_SIZE: {config['BATCH_SIZE']}")
        datamodule = with_builtin_transforms[dataset](
            data_dir=config['DATAPATH'], batch_size=config['BATCH_SIZE'])
        datamodule.setup('fit')
        return datamodule

    raise ValueError(f"Unknown DATASET {dataset!r}")
